#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "../src/em.c"

/* One regression case for em(): the inputs fed to it and the results it is
   expected to reproduce.  The struct tag no longer starts with "_E" (an
   identifier reserved for the implementation); the typedef name is unchanged. */
typedef struct EmTestDatum EmTestDatum;
struct EmTestDatum
{
  struct
    {
      /* Passed as em()'s first argument.  In the test table every 25-float
         slice has -1 at offsets 0, 6, 12, 18, 24, i.e. it looks like four
         5x5 matrices with a fixed diagonal -- TODO confirm against em.c.  */
      float    alpha[4 * 5 * 5];
      /* Starting value for the logbeta that em() updates through a pointer. */
      float    logbeta;
      /* 5x5 table (diagonal 1 in the test data); not read by test_em().  */
      float    gamma[5 * 5];
      /* Not read by test_em(), which starts pz from drand48() instead.  */
      float    pz[5];
      /* Class prior; test_em() passes log(priorz[j]) to em().  */
      float    priorz[5];
      /* Rating pairs handed to em() together with the count 8.  The meaning
         of the two fields of Rating is defined in em.c -- not visible here. */
      Rating   ratings[8];
      /* Not read by test_em().  */
      float    alphastddev;
      /* Mean / std-dev of the normal prior built for logbeta.  */
      float    logbetamean;
      float    logbetastddev;
      /* Not read by test_em().  */
      float    gammamean;
      float    gammastddev;
      /* Passed to em() unchanged.  The table uses -1 or a class index; for
         a non-negative value the expected pz concentrates on that index.  */
      int      clamp;
    }                   inputs;

  struct
    {
      /* Expected return value of em().  */
      float    newq;
      /* Expected logbeta after em() has updated it.  */
      float    logbeta;
      /* Expected pz after em() has updated it.  */
      float    pz[5];
    }                   desired_output;
};

/* Nonzero when ACTUAL and EXPECTED agree within a mixed tolerance:
     |actual - expected| <= rel_tol * (1 + |actual| + |expected|).
   The "1 +" term keeps the check meaningful when both values are near 0.  */
static int
test_em_close (double actual, double expected, double rel_tol)
{
  return fabs (actual - expected)
         <= rel_tol * (1 + fabs (actual) + fabs (expected));
}

/* Regression test for em(): runs every case in DATA and compares the results
   with the recorded desired_output.

   - The value returned by em() (newq) is a hard check.  A mismatch beyond a
     1e-3 tolerance prints the case index and both values, then aborts.  This
     is done explicitly rather than with assert() so that the test still
     fails when the file is compiled with NDEBUG.
   - The updated logbeta and pz are soft checks (1e-2 tolerance).  A mismatch
     only prints a warning and never fails the test.  */
static void
test_em (void)
{
  /* Dimensions baked into EmTestDatum: 5 classes (pz, priorz), 8 ratings. */
  enum { EM_TEST_NCLASSES = 5, EM_TEST_NRATINGS = 8 };

  EmTestDatum data[] = 
    {{{{-1., -0.351049, -0.180766, -0.924837, -0.126337, -0.856209, -1., -0.210468, -0.296039, -0.846668, -0.66282, -0.347971, -1., -0.571906, -0.110637, -0.714098, -0.542027, -0.757586, -1., -0.777415, -0.0743018, -0.736379, -0.505554, -0.91589, -1., -1., -0.723796, -0.0165743, -0.438353, -0.885353, -0.372747, -1., -0.835809, -0.513516, -0.759016, -0.516537, -0.625341, -1., -0.217477, -0.912348, -0.853717, -0.277369, -0.64557, -1., -0.80171, -0.139619, -0.735343, -0.887985, -0.0242956, -1., -1., -0.0653171, -0.998963, -0.38243, -0.108406, -0.341522, -1., -0.982389, -0.944078, -0.223053, -0.968775, -0.146581, -1., -0.430562, -0.464037, -0.452237, -0.52124, -0.213085, -1., -0.551689, -0.59852, -0.24387, -0.567515, -0.749979, -1., -1., -0.458901, -0.508528, -0.67953, -0.725683, -0.393584, -1., -0.509564, -0.2971, -0.617278, -0.0520625, -0.527175, -1., -0.353022, -0.394225, -0.0832877, -0.380594, -0.922461, -1., -0.930187, -0.63105, -0.859355, -0.709376, -0.378498, -1.},  0., {1., 0.96747, 0.384516, 0.858139, 0.371481, 0.426371, 1.,  0.893043, 0.537669, 0.0971647, 0.819955, 0.402607, 1., 0.834769,  0.714442, 0.872018, 0.929782, 0.187791, 1., 0.108667, 0.955305,  0.310377, 0.110251, 0.0388542, 1.}, {0.586356, 0.169731, 0.819627,  0.417352, 0.618886}, {0.785216, 0.961488, 0.0458707, 0.192515,  0.892172}, {{1, 1}, {1, 0}, {3, 0}, {3, 0}, {0, 2}, {2, 2}, {2,  4}, {0, 3}}, 0.423819, 0.948706, 0.37256, 0.489565,  0.58905, -1}, {-12.030650781158147293,  0.77582364331042135959, {0.883692, 0.0067689, 0.0025097, 0.0010604,  0.105969}}}, {{{-1., -0.499458, -0.440218, -0.59874, -0.874403, -0.454763, -1., -0.750594, -0.708991, -0.913257, -0.0411189, -0.920326, -1., -0.528618, -0.330609, -0.660005, -0.705541, -0.490106, -1., -0.37648, -0.85252, -0.597714, -0.913925, -0.325186, -1., -1., -0.225079, -0.0872784, -0.502976, -0.55945, -0.725621, -1., -0.647061, -0.904236, -0.685047, -0.270858, -0.896467, -1., -0.195244, -0.771789, -0.229739, -0.976141, -0.666626, -1., -0.44118, 
-0.569734, -0.2706, -0.17652, -0.0646996, -1., -1., -0.717215, -0.672886, -0.262594, -0.739513, -0.492136, -1., -0.585608, -0.759618, -0.180064, -0.766514, -0.938547, -1., -0.855383, -0.495017, -0.495656, -0.0420802, -0.660139, -1., -0.723228, -0.265917, -0.0659394, -0.993513, -0.282048, -1., -1., -0.696183, -0.79534, -0.816993, -0.217348, -0.978968, -1., -0.122454, -0.554399, -0.477835, -0.486833, -0.536846, -1., -0.79478, -0.297771, -0.720318, -0.5983, -0.939398, -1., -0.802754, -0.224662, -0.556219, -0.279259, -0.0795265, -1.},  0., {1., 0.0412554, 0.50972, 0.714254, 0.202521, 0.737438, 1.,  0.30506, 0.531247, 0.41987, 0.716407, 0.427514, 1., 0.0856455,  0.897705, 0.203239, 0.96436, 0.880426, 1., 0.195476, 0.923558,  0.56266, 0.819823, 0.998231, 1.}, {0.14822, 0.118879, 0.0990823,  0.0777571, 0.106964}, {0.609159, 0.384829, 0.875236, 0.369526,  0.3041}, {{3, 4}, {0, 0}, {1, 3}, {3, 3}, {0, 2}, {3, 0}, {1,  3}, {0, 3}}, 0.853582, 0.455366, 0.653119, 0.876586, 0.767936,  2}, {-15.175884673826575817, -0.20481165704696246849, {7.43105e-6, 7.67716e-6, 0.999963, 0.0000178152,  3.69984e-6}}}, {{{-1., -0.55012, -0.0877742, -0.11249, -0.637815, -0.473678, -1., -0.650434, -0.932313, -0.636046, -0.621897, -0.769313, -1., -0.0313952, -0.713803, -0.728862, -0.378473, -0.416224, -1., -0.589039, -0.0983874, -0.682572, -0.269806, -0.0444043, -1., -1., -0.751507, -0.559158, -0.0377419, -0.602065, -0.201386, -1., -0.471384, -0.925252, -0.96425, -0.727708, -0.82095, -1., -0.992939, -0.328204, -0.105811, -0.0516368, -0.961544, -1., -0.614401, -0.376949, -0.673164, -0.545321, -0.0253621, -1., -1., -0.278562, -0.990592, -0.275515, -0.980958, -0.527056, -1., -0.431434, -0.237773, -0.378893, -0.325669, -0.96005, -1., -0.312521, -0.414643, -0.597961, -0.1391, -0.319581, -1., -0.0864393, -0.49215, -0.0874631, -0.358037, -0.472039, -1., -1., -0.1152, -0.414299, -0.812716, -0.446677, -0.836638, -1., -0.423707, -0.537201, -0.465719, -0.309583, -0.992273, -1., -0.299428, -0.086826, 
-0.983913, -0.0322231, -0.986908, -1., -0.672183, -0.385953, -0.893123, -0.667326, -0.585744, -1.},  0., {1., 0.106197, 0.19434, 0.69071, 0.886295, 0.221397, 1.,  0.608639, 0.503427, 0.332971, 0.0580354, 0.0323458, 1., 0.0406282,  0.79869, 0.367618, 0.0246188, 0.340057, 1., 0.885516, 0.351532,  0.056842, 0.326964, 0.557699, 1.}, {0.737484, 0.949965, 0.994291,  0.143443, 0.631287}, {0.755625, 0.30358, 0.257148, 0.40989,  0.146986}, {{1, 2}, {3, 1}, {2, 3}, {1, 2}, {2, 1}, {1, 0}, {3,  4}, {0, 3}}, 0.800153, 0.924177, 0.351855, 0.114641,  0.759525, -1}, {-13.314497038345293443,  0.59494427352496673944, {0.052751, 0.185585, 0.210893, 0.478297,  0.0724736}}}, {{{-1., -0.0157633, -0.909978, -0.580531, -0.76003, -0.367295, -1., -0.96682, -0.907495, -0.317729, -0.104779, -0.916785, -1., -0.901786, -0.461172, -0.736067, -0.672411, -0.205366, -1., -0.71832, -0.145957, -0.819397, -0.00551972, -0.642496, -1., -1., -0.497812, -0.934038, -0.765045, -0.767983, -0.482049, -1., -0.0240593, -0.184514, -0.00795338, -0.114754, -0.057239, -1., -0.277018, -0.690225, -0.00997441, -0.140454, -0.375232, -1., -0.229053, -0.273908, -0.468043, -0.169866, -0.510733, -1., -1., -0.127951, -0.648646, -0.164346, -0.868237, -0.630139, -1., -0.714608, -0.399301, -0.100254, -0.14809, -0.690549, -1., -0.214787, -0.0923006, -0.0333365, -0.63331, -0.937769, -1., -0.402076, -0.0233621, -0.492857, -0.562537, -0.173023, -1., -1., -0.749454, -0.0248136, -0.392672, -0.66229, -0.621504, -1., -0.376168, -0.228326, -0.794053, -0.991365, -0.661559, -1., -0.829024, -0.693799, -0.843275, -0.97101, -0.614237, -1., -0.601498, -0.809938, -0.3377, -0.676468, -0.199422, -1.},  0., {1., 0.213424, 0.155157, 0.886069, 0.973601, 0.962878, 1.,  0.17997, 0.278741, 0.63589, 0.584382, 0.556138, 1., 0.507066,  0.429943, 0.575747, 0.217697, 0.336091, 1., 0.123742, 0.419022,  0.188707, 0.950328, 0.72524, 1.}, {0.22896, 0.526407, 0.626796,  0.924662, 0.0155362}, {0.37125, 0.740726, 0.951061, 0.0526579,  0.191279}, {{1, 4}, {3, 
4}, {3, 2}, {2, 1}, {1, 4}, {0, 3}, {3,  0}, {0, 2}}, 0.461985, 0.315171, 0.468276, 0.635141, 0.954919,  1}, {-13.404715516857457632,  0.13567747959958528038, {1.48237e-6, 0.999976, 0.0000121831,  1.02068e-6,  9.11539e-6}}}, {{{-1., -0.107471, -0.582556, -0.381172, -0.238514, -0.526493, -1., -0.771263, -0.3315, -0.963754, -0.755453, -0.297669, -1., -0.958296, -0.888416, -0.770989, -0.668919, -0.699022, -1., -0.839478, -0.823647, -0.860198, -0.161007, -0.154648, -1., -1., -0.291923, -0.49534, -0.115927, -0.0398762, -0.184452, -1., -0.912784, -0.734755, -0.801362, -0.657959, -0.141522, -1., -0.403255, -0.837608, -0.902506, -0.843852, -0.444959, -1., -0.949192, -0.131517, -0.174933, -0.745937, -0.109714, -1., -1., -0.30787, -0.314735, -0.58493, -0.955066, -0.0159474, -1., -0.819395, -0.469003, -0.91519, -0.831496, -0.906611, -1., -0.734249, -0.113828, -0.173537, -0.76509, -0.330994, -1., -0.27622, -0.27103, -0.921237, -0.886034, -0.327028, -1., -1., -0.139513, -0.746304, -0.140097, -0.217313, -0.831643, -1., -0.431569, -0.555167, -0.262247, -0.815696, -0.612173, -1., -0.0861642, -0.347058, -0.9842, -0.705562, -0.351916, -1., -0.23323, -0.810664, -0.940473, -0.0209221, -0.95701, -1.},  0., {1., 0.460367, 0.980764, 0.865112, 0.370018, 0.59988, 1.,  0.727068, 0.0052095, 0.587331, 0.431523, 0.158637, 1., 0.560377,  0.849578, 0.247219, 0.77081, 0.646541, 1., 0.196636, 0.231419,  0.476372, 0.998457, 0.429866, 1.}, {0.0420827, 0.416845,  0.0193789, 0.386876, 0.581716}, {0.436081, 0.154267, 0.0168583,  0.981835,  0.709013}, {{1, 3}, {3, 1}, {2, 4}, {3, 4}, {1, 4}, {0, 3}, {1,  3}, {2, 0}}, 0.149057, 0.429527, 0.550312, 0.550376, 0.58868,  1}, {-14.453339697224032905,  0.031148386176501442038, {1.75492e-6, 0.999989, 2.31002e-6,  3.8639e-6,  3.40746e-6}}}, {{{-1., -0.696907, -0.220434, -0.057861, -0.616687, -0.928326, -1., -0.696807, -0.0563178, -0.0465523, -0.970409, -0.113652, -1., -0.0756967, -0.433428, -0.552125, -0.549733, -0.229963, -1., -0.450286, -0.53396, -0.258745, 
-0.37902, -0.879814, -1., -1., -0.0842723, -0.809122, -0.967701, -0.459763, -0.387365, -1., -0.588687, -0.90984, -0.843076, -0.459039, -0.891881, -1., -0.853522, -0.796524, -0.48863, -0.778229, -0.777825, -1., -0.363096, -0.936506, -0.228497, -0.547862, -0.91281, -1., -1., -0.402545, -0.969751, -0.168841, -0.0329957, -0.318273, -1., -0.16063, -0.201141, -0.573233, -0.930908, -0.571942, -1., -0.291301, -0.730156, -0.471869, -0.680061, -0.43778, -1., -0.933632, -0.983238, -0.901832, -0.659955, -0.570536, -1., -1., -0.0467329, -0.673335, -0.112093, -0.657726, -0.644188, -1., -0.703584, -0.943252, -0.624731, -0.325914, -0.542955, -1., -0.742111, -0.0514979, -0.395007, -0.971012, -0.45081, -1., -0.321342, -0.923138, -0.290951, -0.01303, -0.38771, -1.},  0., {1., 0.0601004, 0.610881, 0.646925, 0.182826, 0.106833, 1.,  0.284216, 0.759017, 0.840552, 0.751021, 0.987801, 1., 0.702269,  0.465283, 0.0769354, 0.530755, 0.44438, 1., 0.516781, 0.471942,  0.501768, 0.895189, 0.838122, 1.}, {0.39508, 0.792719, 0.908219,  0.225832, 0.33498}, {0.181838, 0.261295, 0.0430064, 0.228146,  0.897622}, {{1, 3}, {0, 4}, {2, 4}, {2, 3}, {2, 0}, {0, 0}, {2,  2}, {2, 0}}, 0.502277, 0.202454, 0.477126, 0.909821, 0.800008,  1}, {-14.193919066717935768, -0.096612365360480156168, {0.0000153658, 0.99996, 9.89174e-6, 5.33722e-6,  9.26167e-6}}}, {{{-1., -0.59981, -0.620934, -0.644372, -0.779609, -0.0717519, -1., -0.122702, -0.539561, -0.617731, -0.466832, -0.915421, -1., -0.44778, -0.843564, -0.801812, -0.0972588, -0.709075, -1., -0.88657, -0.0299584, -0.99488, -0.211352, -0.0890243, -1., -1., -0.507084, -0.904701, -0.0113602, -0.826196, -0.907274, -1., -0.283767, -0.366989, -0.0465871, -0.835522, -0.161065, -1., -0.827428, -0.428856, -0.36869, -0.245644, -0.379648, -1., -0.585292, -0.566878, -0.148385, -0.670573, -0.698722, -1., -1., -0.53692, -0.153505, -0.459221, -0.609698, -0.0298358, -1., -0.248803, -0.447861, -0.783502, -0.122562, -0.965037, -1., -0.0808722, -0.736915, -0.28704, -0.803972, 
-0.253444, -1., -0.308059, -0.918349, -0.558328, -0.873797, -0.722767, -1., -1., -0.351471, -0.409944, -0.203224, -0.0240447, -0.814552, -1., -0.256439, -0.744003, -0.414347, -0.784716, -0.00763562, -1., -0.296142, -0.630845, -0.662154, -0.042599, -0.21527, -1., -0.89393, -0.375114, -0.238627, -0.961825, -0.585871, -1.}, 0., {1., 0.543235, 0.319701, 0.911971, 0.136895, 0.894706, 1.,  0.729645, 0.115195, 0.16094, 0.709258, 0.986084, 1., 0.859198,  0.575287, 0.493974, 0.99372, 0.15534, 1., 0.206132, 0.156128,  0.0363186, 0.37061, 0.100063, 1.}, {0.531242, 0.274946, 0.332435,  0.685934, 0.988007}, {0.955244, 0.420464, 0.549039, 0.0933007,  0.225599}, {{1, 0}, {0, 2}, {2, 3}, {2, 2}, {2, 4}, {1, 4}, {2,  3}, {1, 3}}, 0.305269, 0.388099, 0.384043, 0.239515, 0.446071,  3}, {-12.352007275095229693,  0.35177716650982293842, {2.73023e-7, 2.1221e-7,  5.99734e-7, 0.999998,  1.26542e-6}}}, {{{-1., -0.109931, -0.754205, -0.709269, -0.393321, -0.266058, -1., -0.790523, -0.0798788, -0.493383, -0.7973, -0.0654687, -1., -0.412314, -0.179317, -0.785308, -0.0207128, -0.832778, -1., -0.728356, -0.878608, -0.246312, -0.138047, -0.116455, -1., -1., -0.262651, -0.485827, -0.584117, -0.929266, -0.15272, -1., -0.731622, -0.874848, -0.535946, -0.886662, -0.941099, -1., -0.79497, -0.0425623, -0.0893616, -0.87563, -0.382656, -1., -0.863245, -0.304054, -0.854917, -0.549878, -0.134889, -1., -1., -0.425446, -0.608606, -0.411831, -0.018434, -0.162795, -1., -0.122779, -0.827714, -0.0891677, -0.0100741, -0.391157, -1., -0.952866, -0.553222, -0.123412, -0.450058, -0.157896, -1., -0.51066, -0.0340506, -0.574428, -0.77524, -0.647415, -1., -1., -0.729997, -0.71951, -0.225362, -0.512526, -0.304551, -1., -0.110905, -0.813531, -0.494092, -0.141756, -0.988126, -1., -0.985817, -0.404924, -0.131682, -0.596969, -0.0329516, -1., -0.851702, -0.00827006, -0.146911, -0.875056, -0.341043, -1.},  0., {1., 0.0257805, 0.427516, 0.900185, 0.306372, 0.755777, 1.,  0.147027, 0.125547, 0.818898, 0.0603279, 0.257931, 1., 
0.939078,  0.31299, 0.202084, 0.246057, 0.924895, 1., 0.717915, 0.333766,  0.843026, 0.957847, 0.569617, 1.}, {0.342036, 0.989937, 0.832902,  0.91066, 0.316256}, {0.56242, 0.932718, 0.604288, 0.560479,  0.415394}, {{1, 4}, {3, 1}, {2, 1}, {1, 2}, {2, 0}, {0, 1}, {3,  0}, {1, 4}}, 0.807171, 0.785389, 0.500151, 0.157463,  0.868093, -1}, {-12.246895039352010397,  0.55281803368122786436, {0.446366, 0.527909, 0.000971106,  0.0111664,  0.0135878}}}, {{{-1., -0.701933, -0.0885942, -0.0568022, -0.245516, -0.0356995, -1., -0.93162, -0.014649, -0.815133, -0.377736, -0.921557, -1., -0.847551, -0.725793, -0.693992, -0.483977, -0.780269, -1., -0.330081, -0.254471, -0.899371, -0.58744, -0.11547, -1., -1., -0.754622, -0.0568336, -0.455533, -0.587869, -0.0526886, -1., -0.968239, -0.398731, -0.342353, -0.0169891, -0.0366196, -1., -0.384082, -0.52722, -0.639253, -0.115063, -0.536531, -1., -0.801427, -0.945261, -0.631086, -0.756261, -0.471346, -1., -1., -0.69079, -0.731715, -0.168821, -0.355876, -0.936169, -1., -0.674881, -0.713288, -0.768007, -0.88348, -0.706641, -1., -0.314556, -0.425654, -0.866491, -0.670022, -0.930474, -1., -0.898434, -0.227238, -0.554959, -0.393944, -0.0970077, -1., -1., -0.281977, -0.923874, -0.637683, -0.625662, -0.591186, -1., -0.192159, -0.468862, -0.269786, -0.655018, -0.517278, -1., -0.755574, -0.501779, -0.771538, -0.810637, -0.441018, -1., -0.0761248, -0.905047, -0.140615, -0.510543, -0.17769, -1.},  0., {1., 0.322191, 0.414344, 0.8834, 0.919317, 0.604168, 1.,  0.338218, 0.521083, 0.544979, 0.195354, 0.530377, 1., 0.989945,  0.814765, 0.850372, 0.047655, 0.745519, 1., 0.316544, 0.621909,  0.858292, 0.186537, 0.392669, 1.}, {0.526956, 0.998907, 0.69708,  0.570359, 0.204765}, {0.584562, 0.81368, 0.651042, 0.600597,  0.246345}, {{3, 2}, {2, 2}, {2, 0}, {0, 3}, {0, 3}, {3, 4}, {1,  0}, {0, 0}}, 0.292597, 0.106063, 0.405243, 0.715968,  0.302652, -1}, {-11.595641649330341797,  0.065091308885268505028, {0.48989, 0.0130642, 0.242296, 0.228354,  0.0263964}}}, 
{{{-1., -0.445129, -0.331687, -0.442867, -0.0252458, -0.0670383, -1., -0.189979, -0.629403, -0.417914, -0.593995, -0.188886, -1., -0.326483, -0.988273, -0.79876, -0.773448, -0.140163, -1., -0.639315, -0.399357, -0.0197925, -0.43276, -0.745378, -1., -1., -0.8046, -0.73576, -0.735412, -0.0366758, -0.359471, -1., -0.404073, -0.292545, -0.01143, -0.292433, -0.214094, -1., -0.663142, -0.593516, -0.698438, -0.0252082, -0.336659, -1., -0.605242, -0.899679, -0.25176, -0.196496, -0.965927, -1., -1., -0.500322, -0.231968, -0.763736, -0.220549, -0.695722, -1., -0.496208, -0.0283245, -0.183873, -0.336251, -0.0921347, -1., -0.735779, -0.172443, -0.0438186, -0.878041, -0.0726372, -1., -0.578928, -0.345381, -0.852833, -0.735978, -0.973686, -1., -1., -0.445702, -0.601073, -0.539482, -0.00775853, -0.94538, -1., -0.369105, -0.775746, -0.787209, -0.249658, -0.872897, -1., -0.747422, -0.603336, -0.913407, -0.780763, -0.0116423, -1., -0.430892, -0.869588, -0.902722, -0.939005, -0.851965, -1.},  0., {1., 0.475792, 0.950111, 0.796973, 0.121721, 0.921494, 1.,  0.551184, 0.336456, 0.12948, 0.866874, 0.920288, 1., 0.112202,  0.916689, 0.116533, 0.793186, 0.859623, 1., 0.520025, 0.0299397,  0.573948, 0.871265, 0.950917, 1.}, {0.899528, 0.47667, 0.810271,  0.802882, 0.423736}, {0.526559, 0.0132974, 0.681161, 0.502242,  0.975375}, {{2, 4}, {1, 2}, {3, 3}, {3, 0}, {2, 3}, {3, 3}, {2,  1}, {0, 4}}, 0.676842, 0.551681, 0.635367, 0.0550868, 0.56464,  3}, {-11.768327458386212041,  0.63380564066357982525, {1.39794e-7, 1.15732e-7,  1.59978e-7, 0.999999, 4.97081e-7}}}};

  for (unsigned int i = 0; i < sizeof (data) / sizeof (data[0]); ++i)
    {
      NormalDistribution prior_logbeta =
        normal_distribution (data[i].inputs.logbetamean,
                             data[i].inputs.logbetastddev);

      /* em() receives logbeta by pointer; whatever it leaves here is
         compared with desired_output.logbeta below.  */
      float logbeta = data[i].inputs.logbeta;
      float pz[EM_TEST_NCLASSES];
      float logpriorz[EM_TEST_NCLASSES];

      /* NOTE(review): pz starts from drand48() values (seeded in main), not
         from data[i].inputs.pz, which this test never reads.  Likewise
         inputs.gamma, alphastddev, gammamean and gammastddev are unused.  */
      for (unsigned int j = 0; j < EM_TEST_NCLASSES; ++j)
        {
          pz[j] = drand48 ();
          logpriorz[j] = log (data[i].inputs.priorz[j]);
        }

      /* NOTE(review): the meaning of the trailing 0, 1000 and 1e-7 arguments
         is defined by em() in src/em.c, which is not visible here.  They are
         presumably an option flag, an iteration cap and a convergence
         tolerance -- confirm there.  */
      float q = em (data[i].inputs.alpha,
                    EM_TEST_NCLASSES,
                    &logbeta,
                    pz,
                    logpriorz,
                    data[i].inputs.clamp,
                    data[i].inputs.ratings,
                    EM_TEST_NRATINGS,
                    &prior_logbeta.base,
                    0,
                    1000,
                    1e-7);

      /* Hard check: must fail even in NDEBUG builds, so no assert().  */
      if (! test_em_close (q, data[i].desired_output.newq, 1e-3))
        {
          fprintf (stderr, "%u: %g ?= %g\n", i, q, data[i].desired_output.newq);
          abort ();
        }

      /* Soft checks: report drift but never fail the test.  */
      if (! test_em_close (logbeta, data[i].desired_output.logbeta, 1e-2))
        fprintf (stderr, "(oops logbeta ... but this is ok) %g ?= %g\n", logbeta, data[i].desired_output.logbeta);

      for (unsigned int j = 0; j < EM_TEST_NCLASSES; ++j)
        {
          if (! test_em_close (pz[j], data[i].desired_output.pz[j], 1e-2))
            fprintf (stderr, "(oops pz[%u] ... but this is ok) %g ?= %g\n", j, pz[j], data[i].desired_output.pz[j]);
        }
    }
}

/* Test entry point.  drand48() is seeded with a fixed value so that the
   random starting pz values drawn inside test_em() are the same on every
   run; then all EM regression cases are executed.  */
int
main (void)
{
  const long seed = 69;

  srand48 (seed);
  test_em ();

  return 0;
}
